"""
Precompute and cache per-residue sequence representations from ESM-2, projected through ESMFold's `esm_s_combine`/`esm_s_mlp` layers, for linker tuning with ESMFold.
"""
import os
import re
import pickle
import argparse
import typing as T

import torch
import torch.nn as nn
from torch.nn import LayerNorm

from openfold.np import residue_constants

from esm.pretrained import load_model_and_alphabet_local
from esm.esmfold.v1.misc import (
    batch_encode_sequences,
    collate_dense_tensors,
)
from esm.esmfold.v1.esmfold import ESMFold

from utils.data_util import get_seqs
from utils.crop import crop
from utils.logger import Logger
# Module-level logger, taken from the project's ``utils.logger.Logger.logger`` attribute.
logger = Logger.logger



class ESM2_Extractor(ESMFold):
    """ESMFold reduced to its language-model front end.

    Only the ESM-2 language model plus the layers that turn its hidden states
    into ESMFold's sequence representation (``esm_s_combine`` and
    ``esm_s_mlp``) are built. ``ESMFold.__init__`` is intentionally bypassed,
    so the folding trunk and downstream modules are never allocated. The
    helpers used in ``forward`` (``_af2_idx_to_esm_idx``,
    ``_mask_inputs_to_esm``, ``_compute_language_model_representations``)
    and the ``device`` property are inherited from ``ESMFold``.
    """

    def __init__(self, esm_location, esmfold_config=None, **kwargs):
        """Build the language model and its projection layers.

        Args:
            esm_location: Path of a local ESM-2 checkpoint, loaded with
                ``esm.pretrained.load_model_and_alphabet_local``.
            esmfold_config: ESMFold model config; must provide
                ``trunk.sequence_state_dim``. When falsy, an ``ESMFoldConfig``
                is built from ``**kwargs``.
        """
        # Only nn.Module bookkeeping is wanted; ESMFold.__init__ would build
        # the full folding model.
        nn.Module.__init__(self)
        if not esmfold_config:
            # ESMFoldConfig is not imported at module level; import it where
            # it is needed so the kwargs-only construction path works.
            from esm.esmfold.v1.esmfold import ESMFoldConfig
            esmfold_config = ESMFoldConfig(**kwargs)
        self.cfg = esmfold_config

        self.distogram_bins = 64

        self.esm, self.esm_dict = load_model_and_alphabet_local(esm_location)
        # The language model is frozen and kept in half precision.
        self.esm.requires_grad_(False)
        self.esm.half()

        self.esm_feats = self.esm.embed_dim
        self.esm_attns = self.esm.num_layers * self.esm.attention_heads
        self.register_buffer("af2_to_esm", ESMFold._af2_to_esm(self.esm_dict))
        # One learnable mixing weight per LM hidden state (num_layers + 1 of
        # them; presumably the embedding output plus every layer).
        self.esm_s_combine = nn.Parameter(torch.zeros(self.esm.num_layers + 1))

        c_s = self.cfg.trunk.sequence_state_dim
        self.esm_s_mlp = nn.Sequential(
            LayerNorm(self.esm_feats),
            nn.Linear(self.esm_feats, c_s),
            nn.ReLU(),
            nn.Linear(c_s, c_s),
        )

        # 0 is padding, N is unknown residues, N + 1 is mask.
        self.n_tokens_embed = residue_constants.restype_num + 3
        self.pad_idx = 0
        self.unk_idx = self.n_tokens_embed - 2
        self.mask_idx = self.n_tokens_embed - 1

    def forward(
        self,
        aa: torch.Tensor,
        mask: T.Optional[torch.Tensor] = None,
        residx: T.Optional[torch.Tensor] = None,
        masking_pattern: T.Optional[torch.Tensor] = None,
    ):
        """Compute ESMFold's sequence representation from residue tokens.

        Use `model.infer` to run from raw sequence strings.

        Args:
            aa (torch.Tensor): Tensor containing indices corresponding to amino acids. Indices match
                openfold.np.residue_constants.restype_order_with_x.
            mask (torch.Tensor): Binary tensor with 1 meaning position is unmasked and 0 meaning
                position is masked. Defaults to all ones.
            residx (torch.Tensor): Accepted for signature compatibility with ``ESMFold.forward``;
                not used by this extractor.
            masking_pattern (torch.Tensor): Optional binary tensor of the same size as `aa`.
                Positions with 1 are masked before being fed to the language model.

        Returns:
            torch.Tensor: ``esm_s_mlp`` output whose last dimension is
            ``cfg.trunk.sequence_state_dim``; leading dims presumably (batch, length).
        """
        if mask is None:
            mask = torch.ones_like(aa)

        esmaa = self._af2_idx_to_esm_idx(aa, mask)
        if masking_pattern is not None:
            esmaa = self._mask_inputs_to_esm(esmaa, masking_pattern)

        # Assumed (B, L, num_layers + 1, esm_feats) -- the matmul/squeeze below relies on it.
        esm_s = self._compute_language_model_representations(esmaa)
        # The LM may run in fp16; cast to the projection layers' dtype and
        # cut the graph back into the frozen LM.
        esm_s = esm_s.to(self.esm_s_combine.dtype).detach()

        # Softmax-weighted sum over the LM hidden states, then project to c_s.
        layer_weights = self.esm_s_combine.softmax(0).unsqueeze(0)
        esm_s = (layer_weights @ esm_s).squeeze(2)
        return self.esm_s_mlp(esm_s)

    @torch.no_grad()
    def infer(
        self,
        sequences: T.Union[str, T.List[str]],
        residx: T.Optional[torch.Tensor] = None,
        masking_pattern: T.Optional[torch.Tensor] = None,
        residue_index_offset: T.Optional[int] = 0,
        chain_linker: T.Optional[str] = "G" * 25,
    ):
        """Runs a forward pass given input sequences (without gradients).

        Args:
            sequences (Union[str, List[str]]): A list of sequences to encode. Multimers can also be passed in,
                each chain should be separated by a ':' token (e.g. "<chain1>:<chain2>:<chain3>").
            residx (torch.Tensor): Optional residue indices; defaults to those produced by
                ``batch_encode_sequences``. Passed through to ``forward``, which ignores it.
            masking_pattern (torch.Tensor): Optional binary tensor of the same size as the encoded
                `aa`. Positions with 1 will be masked.
            residue_index_offset (int): Forwarded to ``batch_encode_sequences``.
            chain_linker (str): Linker to use between chains if encoding a multimer. Has no effect on single chain
                inputs. Default: length-25 poly-G ("G" * 25).

        Returns:
            torch.Tensor: output of ``forward`` for the encoded batch (linker residues
            inserted by ``batch_encode_sequences`` are presumably included).
        """
        if isinstance(sequences, str):
            sequences = [sequences]

        # The linker mask and chain index are not needed for representation extraction.
        aatype, mask, default_residx, _linker_mask, _chain_index = batch_encode_sequences(
            sequences, residue_index_offset, chain_linker
        )

        if residx is None:
            residx = default_residx
        elif not isinstance(residx, torch.Tensor):
            residx = collate_dense_tensors(residx)

        aatype, mask, residx = (t.to(self.device) for t in (aatype, mask, residx))
        if masking_pattern is not None:
            masking_pattern = masking_pattern.to(self.device)

        return self.forward(
            aatype,
            mask=mask,
            residx=residx,
            masking_pattern=masking_pattern,
        )

  

def load_pretrained_esm(model_name='esmfold_3B_v1.pt', model_dir='checkpoint/esmfold',
                        esm_name='esm2_t36_3B_UR50D.pt'):
    """Build an ``ESM2_Extractor`` and load the ESMFold weights it shares.

    Args:
        model_name: ESMFold checkpoint. A name ending in ``.pt`` is loaded from
            ``model_dir``; anything else is downloaded from the fair-esm
            server as ``<model_name>.pt``.
        model_dir: Directory holding the local checkpoints (the ESMFold file
            when local, and always the ESM-2 language model).
        esm_name: File name of the ESM-2 checkpoint inside ``model_dir``.

    Returns:
        ESM2_Extractor with the checkpoint loaded via ``strict=False``, so
        checkpoint entries for modules the extractor does not build are ignored.
    """
    if model_name.endswith(".pt"):  # local, treat as filepath
        model_path = os.path.join(model_dir, model_name)
        model_data = torch.load(model_path, map_location="cpu")
    else:  # load from hub
        url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
        model_data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")

    cfg = model_data["cfg"]["model"]
    model_state = model_data["model"]

    esm_location = os.path.join(model_dir, esm_name)
    model = ESM2_Extractor(esm_location, esmfold_config=cfg)

    # The ESM-2 weights ("esm.*") were already loaded from esm_location by the
    # constructor, so only report other parameters the checkpoint lacks.
    missing_essential_keys = sorted(
        key for key in set(model.state_dict()) - set(model_state)
        if not key.startswith("esm.")
    )
    if missing_essential_keys:
        logger.info(f"Keys '{', '.join(missing_essential_keys)}' are missing.")

    model.load_state_dict(model_state, strict=False)
    return model



def prepare_seq(seq_path, max_len=None):
    """Load dimer sequences and convert them to ESMFold's multimer format.

    Args:
        seq_path: Path handed to ``utils.data_util.get_seqs``.
        max_len: Optional limit on ``chain1_len + chain2_len``; entries whose
            combined length exceeds it are skipped. ``None`` keeps everything.

    Returns:
        dict mapping each entry name to its ``dimer_seq`` with every ``,``
        replaced by ``:`` (the chain separator ``ESM2_Extractor.infer`` expects).
    """
    dimer_sequences = {}
    for name, seq_info in get_seqs(seq_path).items():
        if max_len is not None:
            total_len = seq_info['chain1_len'] + seq_info['chain2_len']
            if total_len > max_len:
                continue
        dimer_sequences[name] = seq_info['dimer_seq'].replace(',', ':')
    return dimer_sequences



def extract_esm2_representation(model, data, data_mode, data_dir, linker_len=25, crop=True, crop_size=200, renew=False):
    """Compute ESM-2 sequence representations for a split and cache them as a pickle.

    Args:
        model: Object exposing ``infer(seq, chain_linker=...)`` that returns a
            tensor (e.g. ``ESM2_Extractor``).
        data: Mapping name -> sequence. A value is either a ':'-separated chain
            string or a dict whose ``'seqs'`` list is joined with ``':'``.
        data_mode: Split name, used in the cache file name.
        data_dir: Root data directory; output goes to
            ``<data_dir>/esm2_3b_feats/<linker_len>/``.
        linker_len: Number of glycines inserted between chains.
        crop: Only for ``data_mode == 'train'``: name the cache file as cropped.
        crop_size: Crop size embedded in the cropped cache file name.
        renew: Recompute even if the cache file already exists.

    Returns:
        Path of the cache pickle (a dict name -> CPU tensor).
    """
    feat_dir = os.path.join(data_dir, 'esm2_3b_feats', str(linker_len))
    os.makedirs(feat_dir, exist_ok=True)

    if data_mode == 'train' and crop:
        save_path = os.path.join(feat_dir, data_mode + '_crop' + str(crop_size) + '_seq_rep.pickle')
    else:
        save_path = os.path.join(feat_dir, data_mode + '_seq_rep.pickle')

    if os.path.exists(save_path) and not renew:
        logger.info('{} already exists'.format(save_path))
        return save_path

    linker = linker_len * 'G'
    result = {}
    for name, seq in data.items():
        if isinstance(seq, dict):
            seq = ':'.join(seq['seqs'])
        result[name] = model.infer(seq, chain_linker=linker).cpu()

    # Write to a temporary file and rename, so an interrupted dump never leaves
    # a truncated cache that a later run would skip as "already exists".
    tmp_path = save_path + '.tmp'
    with open(tmp_path, mode='wb') as f:
        pickle.dump(result, f)
    os.replace(tmp_path, save_path)
    logger.info('save to {}'.format(save_path))
    return save_path

    

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_type', type=str, default='heterodimer', choices=['heterodimer'])
    parser.add_argument('--data_dir', type=str, default='data/')
    parser.add_argument('--data_mode', type=str, default='train', choices=['train', 'valid', 'test', 'test2', 'vhvl68', 'vhvl171'])
    parser.add_argument('--backbone_dir', type=str, default='checkpoint/esmfold', help='File directory for backbone model')
    parser.add_argument('--linker_len', type=int, default=25)
    parser.add_argument('--crop', default=False, action='store_true', help="whether to crop chains")
    parser.add_argument('--crop_size', type=int, default=200, help='multi-chain cropping, the train seq len')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--renew', default=False, action='store_true')
    args = parser.parse_args()

    model = load_pretrained_esm(model_name='esmfold_3B_v1.pt', model_dir=args.backbone_dir)
    model = model.eval().cuda()

    data_mode = args.data_mode
    # Cropping is only applied to the training split.
    use_crop = data_mode == 'train' and args.crop

    if use_crop:
        logger.info('Perform cropping')
        crop(os.path.join(args.data_dir, data_mode + '_profiling.csv'),
             os.path.join(args.data_dir, data_mode + '_distance_map.pickle'),
             crop_size=args.crop_size,
             spatial_crop_prob=0.5,
             cb_cb_threshold=10,
             seed=args.seed)
        # crop() presumably writes the cropped profile that is read here.
        seq_path = os.path.join(args.data_dir, data_mode + '_crop' + str(args.crop_size) + '_profiling.pickle')
    else:
        seq_path = os.path.join(args.data_dir, data_mode + '_profiling.pickle')

    logger.info('Precompute esm2 seq representation')
    seqs = prepare_seq(seq_path, max_len=None)
    extract_esm2_representation(model, seqs, data_mode, args.data_dir, args.linker_len,
                                crop=use_crop,
                                crop_size=args.crop_size,
                                renew=args.renew)
    

